import json
from collections import Counter

import pandas as pd
import unicodedata
import numpy as np
import math
import jieba
from tqdm import tqdm
import re


class UniVoc:
    """Vocabulary that encodes text with a compact token "matrix".

    Frequent words/characters get dedicated tokens. Every other candidate
    entry (rare BMP characters plus low-frequency words past the cutoff)
    is assigned a unique (s_i, e_j) token pair, so S such entries only
    cost m + n vocabulary slots where m * n = S and m + n is minimal.

    Parameters
    ----------
    flag : optional
        Truthy  -> build the vocabulary from scratch (reads ``voc_all.pkl``,
        writes the four mapping pickles). Falsy/None -> load the mappings
        from the previously saved pickle files in the working directory.
    """

    # Compiled once at class-definition time; covers only the common
    # CJK Unified Ideographs block U+4E00..U+9FA5.
    _CHINESE_RE = re.compile(r'[\u4e00-\u9fa5]')

    def __init__(self, flag=None):
        # jieba segmenter used by ``encode`` to split text into words.
        self.tokenizer = jieba.Tokenizer()
        if flag:
            self.voc = []                  # full ordered token list
            self.voc_x2id = {}             # token -> id
            self.voc_id2x = {}             # id -> token
            self.single_char_map = {}      # entry -> (s_token, e_token)
            self.token_pair_char_map = {}  # (s_token, e_token) -> entry
            self._init_vocabulary()
            # Keep voc_size available on both construction paths.
            self.voc_size = len(self.voc_x2id)
        else:
            # Load the mappings produced by a previous _init_vocabulary run.
            self.voc_x2id = pd.read_pickle("voc_x2id.pkl")
            self.voc_id2x = pd.read_pickle("voc_id2x.pkl")
            self.token_pair_char_map = pd.read_pickle("token_pair_char_map.pkl")
            self.single_char_map = pd.read_pickle("single_char_map.pkl")
            self.voc_size = len(self.voc_x2id)

    def is_chinese(self, char):
        """Return True if *char* begins with a CJK ideograph (U+4E00-U+9FA5)."""
        return self._CHINESE_RE.match(char) is not None

    def is_meaningful(self, char):
        """Return True unless *char* is a control-like character.

        Rejects Unicode general categories Cc/Cf/Cs while keeping
        Co (private use) and — despite the original intent of
        "assigned only" — Cn (unassigned) code points as well.
        """
        try:
            cat = unicodedata.category(char)
        except (TypeError, ValueError):
            # category() rejects non-str / multi-char inputs with TypeError.
            return False
        return not (cat.startswith('C') and cat not in ('Co', 'Cn'))

    def _get_meaningful_chars(self):
        """Collect every meaningful Basic-Multilingual-Plane character."""
        chars = [chr(code) for code in range(0x10000)
                 if self.is_meaningful(chr(code))]
        # Drop the trailing entry (U+FFFF, a noncharacter kept by the
        # Cn exemption above).
        return chars[:-1]

    def _find_min_sum_integer(self, S):
        """Factor S = m * n so that m + n is minimal.

        Returns
        -------
        tuple[int, int, int]
            (m, n, m + n) with m <= n.

        Raises
        ------
        ValueError
            If S is not a positive integer.
        """
        if not isinstance(S, int) or S <= 0:
            raise ValueError("S 必须是正整数")

        min_sum = S + 1
        best_pair = (1, S)
        # Only divisors up to sqrt(S) need checking; the cofactor covers
        # the rest.
        for m in range(1, math.isqrt(S) + 1):
            if S % m == 0:
                n = S // m
                if m + n < min_sum:
                    min_sum = m + n
                    best_pair = (m, n)

        return best_pair[0], best_pair[1], min_sum

    def _init_vocabulary(self):
        """Build the vocabulary from scratch and persist it to pickles."""
        # 1. Every meaningful BMP character is a pair-encoding candidate.
        meaningful_chars = self._get_meaningful_chars()

        voc = []
        # Frequency tables: presumably en/zh count single characters and
        # ens/zhs count whole words — TODO confirm against the code that
        # writes voc_all.pkl.
        voc_data = pd.read_pickle("voc_all.pkl")
        en, zh, zhs, ens = voc_data["en"], voc_data["zh"], voc_data["zhs"], voc_data["ens"]

        # Sort each table's keys by descending frequency.
        en = sorted(en, key=lambda x: en[x], reverse=True)
        ens = sorted(ens, key=lambda x: ens[x], reverse=True)
        zh = sorted(zh, key=lambda x: zh[x], reverse=True)
        zhs = sorted(zhs, key=lambda x: zhs[x], reverse=True)

        # The most frequent entries get dedicated tokens...
        voc += en[:300]
        voc += zh[:4000]
        voc += zhs[:4000]
        voc += ens[:4000]
        # ...while everything past the cutoff falls back to pair encoding
        # (so multi-character words can live in the pair map too).
        meaningful_chars += en[300:]
        meaningful_chars += zh[4000:]
        meaningful_chars += zhs[4000:]
        meaningful_chars += ens[4000:]

        voc = list(set(voc))
        meaningful_chars = list(set(meaningful_chars) - set(voc))

        S = len(meaningful_chars)

        # 2. Pick matrix dimensions m x n = S minimizing m + n, so the S
        #    pair-encoded entries cost only m + n vocabulary slots.
        m, n, min_sum = self._find_min_sum_integer(S)
        print(f"字符数: {S}, 矩阵维度: {m} x {n}, 最小和: {min_sum}")

        # 3. Row/column marker tokens.
        s_tokens = [f"s_{i}" for i in range(m)]
        e_tokens = [f"e_{j}" for j in range(n)]

        # Shuffle so the (row, column) assignment carries no frequency bias.
        np.random.shuffle(meaningful_chars)

        # Map each entry to a unique (s_token, e_token) coordinate.
        char_index = 0
        for i in range(m):
            for j in range(n):
                if char_index >= S:  # defensive; unreachable since m * n == S
                    break
                char = meaningful_chars[char_index]
                self.single_char_map[char] = (s_tokens[i], e_tokens[j])
                self.token_pair_char_map[(s_tokens[i], e_tokens[j])] = char
                char_index += 1

        # 4. Assemble the base vocabulary: control tokens come first so
        #    shuffling below never moves them.
        special_tokens = [
            "<|pad|>", "<|im_start|>", "<|im_end|>", "<|think|>",
            "<|end_think|>", "<|user|>", "<|agent|>", "<|system|>",
            "<|func|>", "<|args|>", "<|unk|>", "<|space|>"
        ]
        self.voc = special_tokens + s_tokens + e_tokens + voc

        # 5. Shuffle everything except the special tokens so ids are random.
        special_count = len(special_tokens)
        non_special = self.voc[special_count:]
        np.random.shuffle(non_special)
        self.voc = special_tokens + non_special

        # 6. Build the id <-> token lookup tables.
        self.voc_x2id = {token: idx for idx, token in enumerate(self.voc)}
        self.voc_id2x = {idx: token for idx, token in enumerate(self.voc)}

        # 7. Persist all mappings so later runs can load instead of rebuild.
        pd.to_pickle(self.voc_id2x, "voc_id2x.pkl")
        pd.to_pickle(self.voc_x2id, "voc_x2id.pkl")
        pd.to_pickle(self.single_char_map, "single_char_map.pkl")
        pd.to_pickle(self.token_pair_char_map, "token_pair_char_map.pkl")
        print(f"词汇表大小: {len(self.voc)}")

    def encode(self, text):
        """Encode *text* into a list of token ids.

        jieba-segmented words with a dedicated token use it directly;
        any other word is broken into characters, each emitted as its
        own token, as an (s, e) token pair, or as <|unk|>.
        """
        words = self.tokenizer.lcut(text)
        token_ids = []

        for word in words:
            if not word.strip():
                # A whitespace-only word collapses to a single <|space|>
                # token (lossy for runs); truly empty words are dropped.
                if word.isspace():
                    token_ids.append(self.voc_x2id["<|space|>"])
                continue

            if word in self.voc_x2id:
                # The whole word has a dedicated token.
                token_ids.append(self.voc_x2id[word])
            else:
                # Fall back to per-character encoding.
                for char in word:
                    if char.isspace():
                        token_ids.append(self.voc_x2id["<|space|>"])
                    elif char in self.single_char_map:
                        # Rare character: emit its (s, e) token pair.
                        s_token, e_token = self.single_char_map[char]
                        token_ids.append(self.voc_x2id[s_token])
                        token_ids.append(self.voc_x2id[e_token])
                    elif char in self.voc_x2id:
                        token_ids.append(self.voc_x2id[char])
                    else:
                        token_ids.append(self.voc_x2id["<|unk|>"])

        return token_ids

    def decode(self, token_ids):
        """Decode a list of token ids back into text.

        A consecutive (s_*, e_*) pair present in token_pair_char_map is
        folded back into its original entry; every other id maps straight
        to its token string. Unknown ids decode as "<|unk|>".
        """
        tokens = []
        i = 0
        n_ids = len(token_ids)
        while i < n_ids:
            current_token = self.voc_id2x.get(token_ids[i], "<|unk|>")

            if current_token == "<|space|>":
                # BUGFIX: previously emitted "\n", which broke encode/decode
                # round-trips for ordinary spaces; encode maps all whitespace
                # to <|space|>, so " " is the least surprising inverse.
                tokens.append(" ")
                i += 1
                continue

            # Try to fold an (s_i, e_j) pair back into one entry.
            if current_token.startswith("s_") and i + 1 < n_ids:
                next_token = self.voc_id2x.get(token_ids[i + 1], "<|unk|>")
                if next_token.startswith("e_"):
                    pair = (current_token, next_token)
                    if pair in self.token_pair_char_map:
                        tokens.append(self.token_pair_char_map[pair])
                        i += 2  # consumed both tokens of the pair
                        continue

            # Not a valid combination: emit the token verbatim.
            tokens.append(current_token)
            i += 1

        return "".join(tokens)

    def split_voc(self):
        """Count per-character frequencies in ``rank_317.jsonl``.

        NOTE(review): the resulting Counter is currently discarded — the
        pickle dump at the end is commented out, so this is a dry run.
        """
        english_clip = Counter()
        with open("rank_317.jsonl", "r", encoding="utf-8") as f:
            data = f.readlines()
        for raw in tqdm(data):
            record = json.loads(raw.strip())
            # Chat markers are treated as plain whitespace for counting.
            text = record["text"].replace("<|im_start|>", " ").replace("<|im_end|>", " ")
            english_clip.update(text)
        # pd.to_pickle({"en": english_clip, "zh": chinese_clip, "zhs": chinese_clips}, "voc_single.pkl")


# Usage example: round-trip a sentence through the saved vocabulary.
if __name__ == "__main__":
    # Load the previously built vocabulary from the pickle files.
    univoc = UniVoc()

    sample = "自然语言处理（NLP）是人工智能的重要分支。"

    # Encode to token ids.
    ids = univoc.encode(sample)
    print(f"编码结果: {ids}")

    # Decode back to text.
    restored = univoc.decode(ids)
    print(f"解码结果: {restored}")

    # Verify the round trip.
    print("原始文本:", sample)
    print("解码文本:", restored)
    print("匹配结果:", "成功" if sample == restored else "失败")